#include <torch/extension.h>

// #include <vector>

// CUDA forward declarations

// Launcher for the forward pass. Only the declaration is visible here; the
// definition is expected to come from the accompanying CUDA source.
// Returns nothing -- presumably the result is written into `output`
// (TODO confirm against the launcher's definition).
void sparse_matmul_cuda_forward_launch(torch::Tensor x, torch::Tensor y,
                                       torch::Tensor index, torch::Tensor output);

// Launcher for the backward pass. Only the declaration is visible here; the
// definition is expected to come from the accompanying CUDA source.
// Returns nothing -- presumably the gradients are written into `grad_x` and
// `grad_y` (TODO confirm against the launcher's definition).
// NOTE(review): the last argument is called `mask` here while the forward
// launcher calls its third argument `index`; verify whether they are the same
// tensor.
void sparse_matmul_cuda_backward_launch(torch::Tensor grad_output,
                                        torch::Tensor grad_x, torch::Tensor grad_y,
                                        torch::Tensor x, torch::Tensor y,
                                        torch::Tensor mask);

// C++ interface

// Input validation helpers for tensors passed in from Python.
//
// TORCH_CHECK is the supported macro for reporting invalid user input; it
// supersedes the deprecated AT_ASSERTM / AT_CHECK aliases. Tensor::is_cuda()
// is queried directly instead of going through the deprecated Tensor::type().
// The macro argument is parenthesized so that any expression can be passed.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so both checks behave as ONE statement; the previous
// two-statement expansion broke when used in an unbraced if/else.
#define CHECK_INPUT(x)     \
  do {                     \
    CHECK_CUDA(x);         \
    CHECK_CONTIGUOUS(x);   \
  } while (0)

// Forward entry point registered with Python as `forward`.
//
// For each tensor, in argument order, verifies that it is a CUDA tensor and
// that it is contiguous, then hands all four tensors to
// sparse_matmul_cuda_forward_launch. Nothing is returned; presumably the
// launcher fills `output` (TODO confirm against the launcher's definition).
void sparse_matmul_cuda_forward(torch::Tensor x, torch::Tensor y,
                                torch::Tensor index, torch::Tensor output) {
  // The same pair of checks CHECK_INPUT performs, written out per tensor.
  CHECK_CUDA(x);
  CHECK_CONTIGUOUS(x);

  CHECK_CUDA(y);
  CHECK_CONTIGUOUS(y);

  CHECK_CUDA(index);
  CHECK_CONTIGUOUS(index);

  CHECK_CUDA(output);
  CHECK_CONTIGUOUS(output);

  sparse_matmul_cuda_forward_launch(x, y, index, output);
}

// Backward entry point registered with Python as `backward`.
//
// For each tensor, in argument order, verifies that it is a CUDA tensor and
// that it is contiguous, then hands all six tensors to
// sparse_matmul_cuda_backward_launch. Nothing is returned; presumably the
// launcher fills `grad_x` and `grad_y` (TODO confirm against the launcher's
// definition).
void sparse_matmul_cuda_backward(torch::Tensor grad_output,
                                 torch::Tensor grad_x, torch::Tensor grad_y,
                                 torch::Tensor x, torch::Tensor y,
                                 torch::Tensor mask) {
  // The same pair of checks CHECK_INPUT performs, written out per tensor.
  CHECK_CUDA(grad_output);
  CHECK_CONTIGUOUS(grad_output);

  CHECK_CUDA(grad_x);
  CHECK_CONTIGUOUS(grad_x);

  CHECK_CUDA(grad_y);
  CHECK_CONTIGUOUS(grad_y);

  CHECK_CUDA(x);
  CHECK_CONTIGUOUS(x);

  CHECK_CUDA(y);
  CHECK_CONTIGUOUS(y);

  CHECK_CUDA(mask);
  CHECK_CONTIGUOUS(mask);

  sparse_matmul_cuda_backward_launch(grad_output, grad_x, grad_y, x, y, mask);
}

// Python bindings: registers the two C++ wrappers on the extension module
// under the names `forward` and `backward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &sparse_matmul_cuda_forward,
        "sparse matmul forward (CUDA)");

  m.def("backward",
        &sparse_matmul_cuda_backward,
        "sparse matmul backward (CUDA)");
}
